"""Stomp message sending."""
import contextlib
import enum
import json
import os
import threading
import typing
import uuid

import stomp

from . import certs
from . import logger
from . import misc

# Module-level logger, obtained through the project's logger helper.
LOGGER = logger.get_logger(__name__)


class SendListener(stomp.listener.ConnectionListener):
    """Wait for a response to a SEND frame.

    Register the listener on a connection before sending a frame with a
    ``receipt`` header. `wait()` then blocks until one of these is observed:
    the RECEIPT frame with the matching ``receipt-id``, an ERROR frame, or a
    disconnect.
    """

    Status = enum.Enum('Status', ['NONE', 'RECEIPT', 'ERROR', 'DISCONNECT'])

    def __init__(self, receipt):
        """Create a listener waiting for the RECEIPT with id `receipt`."""
        self.receipt = receipt
        self.condition = threading.Condition()
        self.status = self.Status.NONE

    def _notify(self, status: Status) -> None:
        """Record the first response and wake up waiting threads.

        Only the first response is kept. Events are delivered from another
        thread, so a disconnect arriving right after the RECEIPT (or after an
        ERROR) could otherwise overwrite that status before `wait()` sees it.
        That would turn an acknowledged send into a failure, or hide the
        error.
        """
        LOGGER.debug('received %s', status)
        with self.condition:
            if self.status == self.Status.NONE:
                self.status = status
            self.condition.notify_all()

    def on_receipt(self, frame):
        """Notify about a receipt, ignoring receipts for other frames."""
        if frame.headers.get('receipt-id') == self.receipt:
            self._notify(self.Status.RECEIPT)

    def on_error(self, frame):
        """Notify about an error."""
        self._notify(self.Status.ERROR)

    def on_disconnected(self):
        """Notify about a disconnect."""
        self._notify(self.Status.DISCONNECT)

    def wait(self, timeout: typing.Optional[float] = None) -> None:
        """Wait for a response to a SEND frame.

        Args:
            timeout: maximum number of seconds to wait. None (the default)
                waits until a response arrives.

        Raises:
            TimeoutError: no response arrived within `timeout` seconds.
            Exception: the first response was not the expected RECEIPT
                (an ERROR frame or a disconnect).
        """
        with self.condition:
            responded = self.condition.wait_for(
                lambda: self.status != self.Status.NONE, timeout)
            if not responded:
                raise TimeoutError(
                    f'No response from server within {timeout} seconds')
            if self.status != self.Status.RECEIPT:
                raise Exception(f'Received {self.status} response from server')


class StompClient:
    """Class for handling UMB communication.

    Connection settings come from the constructor arguments. Any that are
    not given fall back to the STOMP_HOST, STOMP_PORT and STOMP_CERTFILE
    environment variables.
    """

    def __init__(
        self,
        *,
        host: typing.Union[None, str, list[str]] = None,
        port: typing.Optional[int] = None,
        certfile: typing.Optional[str] = None,
    ) -> None:
        """Initialize a client.

        Args:
            host: broker host name(s). A string may contain several
                whitespace-separated host names. If not given, STOMP_HOST is
                used; if that is unset, empty or blank, localhost is used.
            port: port used for every host. If not given, it is read with
                misc.get_env_int('STOMP_PORT', 61612).
            certfile: file passed as both key and certificate for SSL. If not
                given, STOMP_CERTFILE is used. Without one, SSL is not set up.
        """
        if not host:
            host = os.environ.get('STOMP_HOST', '')
        if isinstance(host, str):
            host = host.split()
        if not host:
            # an empty or whitespace-only host configuration would otherwise
            # result in an empty broker list
            host = ['localhost']
        port = int(port or misc.get_env_int('STOMP_PORT', 61612))

        self.brokers = [(h, port) for h in host]
        self.certfile = certfile or os.environ.get('STOMP_CERTFILE')

    @contextlib.contextmanager
    def connect(self) -> typing.Iterator[stomp.Connection]:
        """Connect to the server and yield a connection.

        If a certificate file is configured, SSL is enabled for all brokers,
        using that file as both key and certificate. The connection is
        disconnected when the context exits.
        """
        connection = stomp.Connection(self.brokers, keepalive=True)
        if self.certfile:
            connection.set_ssl(self.brokers,
                               key_file=self.certfile,
                               cert_file=self.certfile)
            certs.update_certificate_metrics(self.certfile)
        connection.connect(wait=True)
        try:
            yield connection
        finally:
            # stomp processing is async, and this will swallow any errors
            # because of a missing disconnect receipt from the server which
            # also acks that any previous messages have been received; the
            # explicit send() receipt below covers reception and processing of
            # any sent message, so checking the disconnect receipt should not
            # be needed
            # https://stomp.github.io/stomp-specification-1.2.html#RECEIPT
            with misc.only_log_exceptions():
                connection.disconnect()

    def send_message(self, data: typing.Any, queue_name: str) -> None:
        """Send message to topic.

        Encode `data` as json and send it to topic=queue_name.

        Each call opens a new connection. The SEND frame carries a unique
        receipt id, and the call blocks until the server has answered it.

        Raises:
            Exception: the server answered with an ERROR frame, or the
                connection was closed before the receipt arrived.
        """
        logging_env = {
            'send_message': {
                'body': data,
                'routing_key': queue_name,
            }
        }
        with logger.logging_env(logging_env):
            LOGGER.info('Sending message with routing_key=%s', queue_name)
            with self.connect() as connection:
                receipt_id = str(uuid.uuid4())
                listener = SendListener(receipt_id)
                connection.set_listener(receipt_id, listener)
                connection.send(queue_name, json.dumps(data), receipt=receipt_id)
                listener.wait()
